{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Mixture of Agents\n",
    "\n",
    "[Mixture of Agents](https://arxiv.org/abs/2406.04692) is a multi-agent design pattern\n",
    "that is modeled after the feed-forward neural network architecture.\n",
    "\n",
    "The pattern consists of two types of agents: worker agents and a single orchestrator agent.\n",
    "Worker agents are organized into multiple layers, with each layer consisting of a fixed number of worker agents.\n",
    "Messages from the worker agents in a previous layer are concatenated and sent to\n",
    "all the worker agents in the next layer.\n",
    "\n",
    "This example implements the Mixture of Agents pattern using the core library,\n",
    "following the [original implementation](https://github.com/togethercomputer/moa) of multi-layer Mixture of Agents.\n",
    "\n",
    "Here is a high-level procedure overview of the pattern:\n",
    "1. The orchestrator agent takes a user task as input and first dispatches it to the worker agents in the first layer.\n",
    "2. The worker agents in the first layer process the task and return the results to the orchestrator agent.\n",
    "3. The orchestrator agent then collects the results from the first layer and dispatches an updated task, which includes those results, to the worker agents in the second layer.\n",
    "4. The process continues until the final layer is reached.\n",
    "5. In the final layer, the orchestrator agent aggregates the results from the previous layer and returns a single final result to the user.\n",
    "\n",
    "We use the direct messaging API {py:meth}`~autogen_core.BaseAgent.send_message` to implement this pattern.\n",
    "This makes it easier to add more features like worker task cancellation and error handling in the future."
   ]
  },
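  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The procedure above can be sketched in plain Python, with no model calls, to make the control flow concrete.\n",
    "This is a minimal illustration under stated assumptions, not part of `autogen_core`: `run_layers`, `make_worker`, and `concat` are hypothetical stand-ins\n",
    "for the worker and orchestrator agents, and `num_layers` counts the final aggregation step as a layer, as the orchestrator agent below does.\n",
    "\n",
    "```python\n",
    "import asyncio\n",
    "from typing import Awaitable, Callable, List\n",
    "\n",
    "# Hypothetical stand-ins (no model is called): a worker maps (task, previous_results)\n",
    "# to a response, and the aggregator maps (task, results) to one final answer.\n",
    "Worker = Callable[[str, List[str]], Awaitable[str]]\n",
    "Aggregator = Callable[[str, List[str]], Awaitable[str]]\n",
    "\n",
    "\n",
    "async def run_layers(task: str, workers: List[Worker], aggregate: Aggregator, num_layers: int) -> str:\n",
    "    previous: List[str] = []  # the first layer sees no previous results\n",
    "    for _ in range(num_layers - 1):  # the last layer is the aggregation step\n",
    "        # All workers in a layer run concurrently on the same task and previous results.\n",
    "        previous = list(await asyncio.gather(*(w(task, previous) for w in workers)))\n",
    "    return await aggregate(task, previous)\n",
    "\n",
    "\n",
    "def make_worker(name: str) -> Worker:\n",
    "    async def worker(task: str, previous: List[str]) -> str:\n",
    "        # Report which worker ran and how many previous results it received.\n",
    "        return f\"{name}(saw {len(previous)})\"\n",
    "\n",
    "    return worker\n",
    "\n",
    "\n",
    "async def concat(task: str, results: List[str]) -> str:\n",
    "    return \" | \".join(results)\n",
    "\n",
    "\n",
    "workers = [make_worker(f\"w{i}\") for i in range(3)]\n",
    "final = asyncio.run(run_layers(\"demo\", workers, concat, num_layers=3))\n",
    "print(final)  # w0(saw 3) | w1(saw 3) | w2(saw 3)\n",
    "```\n",
    "\n",
    "Outside a notebook, `asyncio.run` drives the coroutine; inside Jupyter, where an event loop is already running, `await run_layers(...)` directly instead."
   ]
  },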
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "from dataclasses import dataclass\n",
    "from typing import List\n",
    "\n",
    "from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler\n",
    "from autogen_core.models import ChatCompletionClient, SystemMessage, UserMessage\n",
    "from autogen_ext.models.openai import OpenAIChatCompletionClient"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Message Protocol\n",
    "\n",
    "The agents communicate using the following messages:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class WorkerTask:\n",
    "    task: str\n",
    "    previous_results: List[str]\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class WorkerTaskResult:\n",
    "    result: str\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class UserTask:\n",
    "    task: str\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class FinalResult:\n",
    "    result: str"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Worker Agent\n",
    "\n",
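    "Each worker agent receives a task from the orchestrator agent and processes it\n",
    "independently. If the task carries results from the previous layer, the worker synthesizes them\n",
    "into its response; otherwise, it answers the task directly.\n",
    "Once the task is completed, the worker agent returns the result."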
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class WorkerAgent(RoutedAgent):\n",
    "    def __init__(\n",
    "        self,\n",
    "        model_client: ChatCompletionClient,\n",
    "    ) -> None:\n",
    "        super().__init__(description=\"Worker Agent\")\n",
    "        self._model_client = model_client\n",
    "\n",
    "    @message_handler\n",
    "    async def handle_task(self, message: WorkerTask, ctx: MessageContext) -> WorkerTaskResult:\n",
    "        if message.previous_results:\n",
    "            # If previous results are provided, we need to synthesize them to create a single prompt.\n",
    "            system_prompt = \"You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.\\n\\nResponses from models:\"\n",
    "            system_prompt += \"\\n\" + \"\\n\\n\".join([f\"{i+1}. {r}\" for i, r in enumerate(message.previous_results)])\n",
    "            model_result = await self._model_client.create(\n",
    "                [SystemMessage(content=system_prompt), UserMessage(content=message.task, source=\"user\")]\n",
    "            )\n",
    "        else:\n",
    "            # If no previous results are provided, we can simply pass the user query to the model.\n",
    "            model_result = await self._model_client.create([UserMessage(content=message.task, source=\"user\")])\n",
    "        assert isinstance(model_result.content, str)\n",
    "        print(f\"{'-'*80}\\nWorker-{self.id}:\\n{model_result.content}\")\n",
    "        return WorkerTaskResult(result=model_result.content)"
   ]
  },
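  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The worker's handler and the orchestrator's final aggregation build the same synthesis prompt.\n",
    "As a sketch of how that formatting could be shared, the helper below reproduces it in isolation;\n",
    "`build_aggregate_prompt` and `AGGREGATE_PROMPT` are hypothetical names, and the prompt text is abbreviated.\n",
    "\n",
    "```python\n",
    "from typing import List\n",
    "\n",
    "# Hypothetical, shortened prompt; the notebook uses the full synthesis instruction.\n",
    "AGGREGATE_PROMPT = \"Synthesize these responses into a single, high-quality response.\\n\\nResponses from models:\"\n",
    "\n",
    "\n",
    "def build_aggregate_prompt(previous_results: List[str]) -> str:\n",
    "    # Number each response (1-based) and separate responses with blank lines,\n",
    "    # the same formatting used in WorkerAgent.handle_task.\n",
    "    numbered = \"\\n\\n\".join(f\"{i + 1}. {r}\" for i, r in enumerate(previous_results))\n",
    "    return AGGREGATE_PROMPT + \"\\n\" + numbered\n",
    "\n",
    "\n",
    "print(build_aggregate_prompt([\"Alice 144, Bob 192, Charlie 96\", \"144 / 192 / 96\"]))\n",
    "```\n",
    "\n",
    "Calling one shared helper from both agents would keep their prompts from drifting apart if you later edit the instruction."
   ]
  },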
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Orchestrator Agent\n",
    "\n",
    "The orchestrator agent receives tasks from the user and distributes them to the worker agents,\n",
    "iterating over multiple layers of worker agents. Once the workers in the last worker layer have returned their results,\n",
    "the orchestrator agent aggregates them with its own model call and returns the final result to the sender."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class OrchestratorAgent(RoutedAgent):\n",
    "    def __init__(\n",
    "        self,\n",
    "        model_client: ChatCompletionClient,\n",
    "        worker_agent_types: List[str],\n",
    "        num_layers: int,\n",
    "    ) -> None:\n",
    "        super().__init__(description=\"Orchestrator Agent\")\n",
    "        self._model_client = model_client\n",
    "        self._worker_agent_types = worker_agent_types\n",
    "        self._num_layers = num_layers\n",
    "\n",
    "    @message_handler\n",
    "    async def handle_task(self, message: UserTask, ctx: MessageContext) -> FinalResult:\n",
    "        print(f\"{'-'*80}\\nOrchestrator-{self.id}:\\nReceived task: {message.task}\")\n",
    "        # Create task for the first layer.\n",
    "        worker_task = WorkerTask(task=message.task, previous_results=[])\n",
    "        # Iterate over layers.\n",
    "        for i in range(self._num_layers - 1):\n",
    "            # Assign workers for this layer.\n",
    "            worker_ids = [\n",
    "                AgentId(worker_type, f\"{self.id.key}/layer_{i}/worker_{j}\")\n",
    "                for j, worker_type in enumerate(self._worker_agent_types)\n",
    "            ]\n",
    "            # Dispatch tasks to workers.\n",
    "            print(f\"{'-'*80}\\nOrchestrator-{self.id}:\\nDispatch to workers at layer {i}\")\n",
    "            results = await asyncio.gather(*[self.send_message(worker_task, worker_id) for worker_id in worker_ids])\n",
    "            print(f\"{'-'*80}\\nOrchestrator-{self.id}:\\nReceived results from workers at layer {i}\")\n",
    "            # Prepare task for the next layer.\n",
    "            worker_task = WorkerTask(task=message.task, previous_results=[r.result for r in results])\n",
    "        # Perform final aggregation.\n",
    "        print(f\"{'-'*80}\\nOrchestrator-{self.id}:\\nPerforming final aggregation\")\n",
    "        system_prompt = \"You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.\\n\\nResponses from models:\"\n",
    "        system_prompt += \"\\n\" + \"\\n\\n\".join([f\"{i+1}. {r}\" for i, r in enumerate(worker_task.previous_results)])\n",
    "        model_result = await self._model_client.create(\n",
    "            [SystemMessage(content=system_prompt), UserMessage(content=message.task, source=\"user\")]\n",
    "        )\n",
    "        assert isinstance(model_result.content, str)\n",
    "        return FinalResult(result=model_result.content)"
   ]
  },
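  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each worker `AgentId` key encodes the orchestrator's key, the layer index, and the worker index, so every (layer, worker) pair\n",
    "is handled by a distinct agent instance. The hypothetical helper `worker_keys` below reproduces only that string formatting,\n",
    "without importing `autogen_core`, to show which keys one run generates: with `num_layers=3` the loop covers layers 0 and 1,\n",
    "and the third layer is the orchestrator's own aggregation call, matching the output further down.\n",
    "\n",
    "```python\n",
    "from typing import List, Tuple\n",
    "\n",
    "\n",
    "def worker_keys(orchestrator_key: str, worker_types: List[str], num_layers: int) -> List[Tuple[str, str]]:\n",
    "    # Same (type, key) pairs the orchestrator passes to AgentId for each dispatch.\n",
    "    return [\n",
    "        (worker_type, f\"{orchestrator_key}/layer_{i}/worker_{j}\")\n",
    "        for i in range(num_layers - 1)\n",
    "        for j, worker_type in enumerate(worker_types)\n",
    "    ]\n",
    "\n",
    "\n",
    "ids = worker_keys(\"default\", [\"worker\"] * 3, num_layers=3)\n",
    "print(len(ids))  # 6\n",
    "print(ids[0])  # ('worker', 'default/layer_0/worker_0')\n",
    "print(ids[-1])  # ('worker', 'default/layer_1/worker_2')\n",
    "# Model calls per user task: one per worker instance plus the final aggregation.\n",
    "print(len(ids) + 1)  # 7\n",
    "```"
   ]
  },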
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running Mixture of Agents\n",
    "\n",
    "Let's run the mixture of agents on a math task. You can change the task to make it more challenging, for example, by trying tasks from the [International Mathematical Olympiad](https://www.imo-official.org/problems.aspx)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "task = (\n",
    "    \"I have 432 cookies, and divide them 3:4:2 between Alice, Bob, and Charlie. How many cookies does each person get?\"\n",
    ")"
   ]
  },
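  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Knowing the expected answer makes it easy to check the agents' output. Since 3 + 4 + 2 = 9 parts and 432 / 9 = 48 cookies per part,\n",
    "the shares are 144, 192, and 96. The hypothetical helper `split_by_ratio` below performs the same computation;\n",
    "it assumes the total divides evenly into whole cookies.\n",
    "\n",
    "```python\n",
    "from typing import List\n",
    "\n",
    "\n",
    "def split_by_ratio(total: int, ratio: List[int]) -> List[int]:\n",
    "    # One part is worth total / sum(ratio); each share is its parts times that value.\n",
    "    parts = sum(ratio)\n",
    "    if total % parts != 0:\n",
    "        raise ValueError(\"total is not evenly divisible by the sum of the ratio\")\n",
    "    per_part = total // parts\n",
    "    return [r * per_part for r in ratio]\n",
    "\n",
    "\n",
    "print(split_by_ratio(432, [3, 4, 2]))  # [144, 192, 96]\n",
    "```"
   ]
  },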
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's set up the runtime with `num_layers=3`: two layers of worker agents, each consisting of 3 worker agents,\n",
    "followed by a final aggregation layer handled by the orchestrator agent.\n",
    "We only need to register a single worker agent type, \"worker\", because we are using\n",
    "the same model client configuration (i.e., gpt-4o-mini) for all worker agents.\n",
    "If you want to use different models, you will need to register multiple worker agent types,\n",
    "one for each model, and update the `worker_agent_types` list in the orchestrator agent's\n",
    "factory function.\n",
    "\n",
    "The instances of worker agents are automatically created when the orchestrator agent\n",
    "dispatches tasks to them.\n",
    "See [Agent Identity and Lifecycle](../core-concepts/agent-identity-and-lifecycle.md)\n",
    "for more information on agent lifecycle."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "Orchestrator-orchestrator:default:\n",
      "Received task: I have 432 cookies, and divide them 3:4:2 between Alice, Bob, and Charlie. How many cookies does each person get?\n",
      "--------------------------------------------------------------------------------\n",
      "Orchestrator-orchestrator:default:\n",
      "Dispatch to workers at layer 0\n",
      "--------------------------------------------------------------------------------\n",
      "Worker-worker:default/layer_0/worker_1:\n",
      "To divide 432 cookies in the ratio of 3:4:2 between Alice, Bob, and Charlie, you first need to determine the total number of parts in the ratio.\n",
      "\n",
      "Add the parts together:\n",
      "\\[ 3 + 4 + 2 = 9 \\]\n",
      "\n",
      "Now, you can find the value of one part by dividing the total number of cookies by the total number of parts:\n",
      "\\[ \\text{Value of one part} = \\frac{432}{9} = 48 \\]\n",
      "\n",
      "Now, multiply the value of one part by the number of parts for each person:\n",
      "\n",
      "- For Alice (3 parts):\n",
      "\\[ 3 \\times 48 = 144 \\]\n",
      "\n",
      "- For Bob (4 parts):\n",
      "\\[ 4 \\times 48 = 192 \\]\n",
      "\n",
      "- For Charlie (2 parts):\n",
      "\\[ 2 \\times 48 = 96 \\]\n",
      "\n",
      "Thus, the number of cookies each person gets is:\n",
      "- Alice: 144 cookies\n",
      "- Bob: 192 cookies\n",
      "- Charlie: 96 cookies\n",
      "--------------------------------------------------------------------------------\n",
      "Worker-worker:default/layer_0/worker_0:\n",
      "To divide 432 cookies in the ratio of 3:4:2 between Alice, Bob, and Charlie, we will first determine the total number of parts in the ratio:\n",
      "\n",
      "\\[\n",
      "3 + 4 + 2 = 9 \\text{ parts}\n",
      "\\]\n",
      "\n",
      "Next, we calculate the value of one part by dividing the total number of cookies by the total number of parts:\n",
      "\n",
      "\\[\n",
      "\\text{Value of one part} = \\frac{432}{9} = 48\n",
      "\\]\n",
      "\n",
      "Now, we can find out how many cookies each person receives by multiplying the value of one part by the number of parts each person receives:\n",
      "\n",
      "- For Alice (3 parts):\n",
      "\\[\n",
      "3 \\times 48 = 144 \\text{ cookies}\n",
      "\\]\n",
      "\n",
      "- For Bob (4 parts):\n",
      "\\[\n",
      "4 \\times 48 = 192 \\text{ cookies}\n",
      "\\]\n",
      "\n",
      "- For Charlie (2 parts):\n",
      "\\[\n",
      "2 \\times 48 = 96 \\text{ cookies}\n",
      "\\]\n",
      "\n",
      "Thus, the number of cookies each person gets is:\n",
      "- **Alice**: 144 cookies\n",
      "- **Bob**: 192 cookies\n",
      "- **Charlie**: 96 cookies\n",
      "--------------------------------------------------------------------------------\n",
      "Worker-worker:default/layer_0/worker_2:\n",
      "To divide the cookies in the ratio of 3:4:2, we first need to find the total parts in the ratio. \n",
      "\n",
      "The total parts are:\n",
      "- Alice: 3 parts\n",
      "- Bob: 4 parts\n",
      "- Charlie: 2 parts\n",
      "\n",
      "Adding these parts together gives:\n",
      "\\[ 3 + 4 + 2 = 9 \\text{ parts} \\]\n",
      "\n",
      "Next, we can determine how many cookies each part represents by dividing the total number of cookies by the total parts:\n",
      "\\[ \\text{Cookies per part} = \\frac{432 \\text{ cookies}}{9 \\text{ parts}} = 48 \\text{ cookies/part} \\]\n",
      "\n",
      "Now we can calculate the number of cookies for each person:\n",
      "- Alice's share: \n",
      "\\[ 3 \\text{ parts} \\times 48 \\text{ cookies/part} = 144 \\text{ cookies} \\]\n",
      "- Bob's share: \n",
      "\\[ 4 \\text{ parts} \\times 48 \\text{ cookies/part} = 192 \\text{ cookies} \\]\n",
      "- Charlie's share: \n",
      "\\[ 2 \\text{ parts} \\times 48 \\text{ cookies/part} = 96 \\text{ cookies} \\]\n",
      "\n",
      "So, the final distribution of cookies is:\n",
      "- Alice: 144 cookies\n",
      "- Bob: 192 cookies\n",
      "- Charlie: 96 cookies\n",
      "--------------------------------------------------------------------------------\n",
      "Orchestrator-orchestrator:default:\n",
      "Received results from workers at layer 0\n",
      "--------------------------------------------------------------------------------\n",
      "Orchestrator-orchestrator:default:\n",
      "Dispatch to workers at layer 1\n",
      "--------------------------------------------------------------------------------\n",
      "Worker-worker:default/layer_1/worker_2:\n",
      "To divide 432 cookies in the ratio of 3:4:2 among Alice, Bob, and Charlie, follow these steps:\n",
      "\n",
      "1. **Determine the total number of parts in the ratio**:\n",
      "   \\[\n",
      "   3 + 4 + 2 = 9 \\text{ parts}\n",
      "   \\]\n",
      "\n",
      "2. **Calculate the value of one part** by dividing the total number of cookies by the total number of parts:\n",
      "   \\[\n",
      "   \\text{Value of one part} = \\frac{432}{9} = 48\n",
      "   \\]\n",
      "\n",
      "3. **Calculate the number of cookies each person receives** by multiplying the value of one part by the number of parts each individual gets:\n",
      "   - **For Alice (3 parts)**:\n",
      "     \\[\n",
      "     3 \\times 48 = 144 \\text{ cookies}\n",
      "     \\]\n",
      "   - **For Bob (4 parts)**:\n",
      "     \\[\n",
      "     4 \\times 48 = 192 \\text{ cookies}\n",
      "     \\]\n",
      "   - **For Charlie (2 parts)**:\n",
      "     \\[\n",
      "     2 \\times 48 = 96 \\text{ cookies}\n",
      "     \\]\n",
      "\n",
      "Thus, the final distribution of cookies is:\n",
      "- **Alice**: 144 cookies\n",
      "- **Bob**: 192 cookies\n",
      "- **Charlie**: 96 cookies\n",
      "--------------------------------------------------------------------------------\n",
      "Worker-worker:default/layer_1/worker_0:\n",
      "To divide 432 cookies among Alice, Bob, and Charlie in the ratio of 3:4:2, we can follow these steps:\n",
      "\n",
      "1. **Calculate the Total Parts**: \n",
      "   Add the parts of the ratio together:\n",
      "   \\[\n",
      "   3 + 4 + 2 = 9 \\text{ parts}\n",
      "   \\]\n",
      "\n",
      "2. **Determine the Value of One Part**: \n",
      "   Divide the total number of cookies by the total number of parts:\n",
      "   \\[\n",
      "   \\text{Value of one part} = \\frac{432 \\text{ cookies}}{9 \\text{ parts}} = 48 \\text{ cookies/part}\n",
      "   \\]\n",
      "\n",
      "3. **Calculate Each Person's Share**:\n",
      "   - **Alice's Share** (3 parts):\n",
      "     \\[\n",
      "     3 \\times 48 = 144 \\text{ cookies}\n",
      "     \\]\n",
      "   - **Bob's Share** (4 parts):\n",
      "     \\[\n",
      "     4 \\times 48 = 192 \\text{ cookies}\n",
      "     \\]\n",
      "   - **Charlie's Share** (2 parts):\n",
      "     \\[\n",
      "     2 \\times 48 = 96 \\text{ cookies}\n",
      "     \\]\n",
      "\n",
      "4. **Final Distribution**:\n",
      "   - Alice: 144 cookies\n",
      "   - Bob: 192 cookies\n",
      "   - Charlie: 96 cookies\n",
      "\n",
      "Thus, the distribution of cookies is:\n",
      "- **Alice**: 144 cookies\n",
      "- **Bob**: 192 cookies\n",
      "- **Charlie**: 96 cookies\n",
      "--------------------------------------------------------------------------------\n",
      "Worker-worker:default/layer_1/worker_1:\n",
      "To divide 432 cookies among Alice, Bob, and Charlie in the ratio of 3:4:2, we first need to determine the total number of parts in this ratio.\n",
      "\n",
      "1. **Calculate Total Parts:**\n",
      "   \\[\n",
      "   3 \\text{ (Alice)} + 4 \\text{ (Bob)} + 2 \\text{ (Charlie)} = 9 \\text{ parts}\n",
      "   \\]\n",
      "\n",
      "2. **Determine the Value of One Part:**\n",
      "   Next, we'll find out how many cookies correspond to one part by dividing the total number of cookies by the total number of parts:\n",
      "   \\[\n",
      "   \\text{Value of one part} = \\frac{432 \\text{ cookies}}{9 \\text{ parts}} = 48 \\text{ cookies/part}\n",
      "   \\]\n",
      "\n",
      "3. **Calculate the Share for Each Person:**\n",
      "   - **Alice's Share (3 parts):**\n",
      "     \\[\n",
      "     3 \\times 48 = 144 \\text{ cookies}\n",
      "     \\]\n",
      "   - **Bob's Share (4 parts):**\n",
      "     \\[\n",
      "     4 \\times 48 = 192 \\text{ cookies}\n",
      "     \\]\n",
      "   - **Charlie’s Share (2 parts):**\n",
      "     \\[\n",
      "     2 \\times 48 = 96 \\text{ cookies}\n",
      "     \\]\n",
      "\n",
      "4. **Summary of the Distribution:**\n",
      "   - **Alice:** 144 cookies\n",
      "   - **Bob:** 192 cookies\n",
      "   - **Charlie:** 96 cookies\n",
      "\n",
      "In conclusion, Alice receives 144 cookies, Bob receives 192 cookies, and Charlie receives 96 cookies.\n",
      "--------------------------------------------------------------------------------\n",
      "Orchestrator-orchestrator:default:\n",
      "Received results from workers at layer 1\n",
      "--------------------------------------------------------------------------------\n",
      "Orchestrator-orchestrator:default:\n",
      "Performing final aggregation\n",
      "--------------------------------------------------------------------------------\n",
      "Final result:\n",
      "To divide 432 cookies among Alice, Bob, and Charlie in the ratio of 3:4:2, follow these steps:\n",
      "\n",
      "1. **Calculate the Total Parts in the Ratio:**\n",
      "   Add the parts of the ratio together:\n",
      "   \\[\n",
      "   3 + 4 + 2 = 9\n",
      "   \\]\n",
      "\n",
      "2. **Determine the Value of One Part:**\n",
      "   Divide the total number of cookies by the total number of parts:\n",
      "   \\[\n",
      "   \\text{Value of one part} = \\frac{432}{9} = 48 \\text{ cookies/part}\n",
      "   \\]\n",
      "\n",
      "3. **Calculate Each Person's Share:**\n",
      "   - **Alice's Share (3 parts):**\n",
      "     \\[\n",
      "     3 \\times 48 = 144 \\text{ cookies}\n",
      "     \\]\n",
      "   - **Bob's Share (4 parts):**\n",
      "     \\[\n",
      "     4 \\times 48 = 192 \\text{ cookies}\n",
      "     \\]\n",
      "   - **Charlie's Share (2 parts):**\n",
      "     \\[\n",
      "     2 \\times 48 = 96 \\text{ cookies}\n",
      "     \\]\n",
      "\n",
      "Therefore, the distribution of cookies is as follows:\n",
      "- **Alice:** 144 cookies\n",
      "- **Bob:** 192 cookies\n",
      "- **Charlie:** 96 cookies\n",
      "\n",
      "In summary, Alice gets 144 cookies, Bob gets 192 cookies, and Charlie gets 96 cookies.\n"
     ]
    }
   ],
   "source": [
    "runtime = SingleThreadedAgentRuntime()\n",
    "model_client = OpenAIChatCompletionClient(model=\"gpt-4o-mini\")\n",
    "await WorkerAgent.register(runtime, \"worker\", lambda: WorkerAgent(model_client=model_client))\n",
    "await OrchestratorAgent.register(\n",
    "    runtime,\n",
    "    \"orchestrator\",\n",
    "    lambda: OrchestratorAgent(model_client=model_client, worker_agent_types=[\"worker\"] * 3, num_layers=3),\n",
    ")\n",
    "\n",
    "runtime.start()\n",
    "result = await runtime.send_message(UserTask(task=task), AgentId(\"orchestrator\", \"default\"))\n",
    "\n",
    "await runtime.stop_when_idle()\n",
    "await model_client.close()\n",
    "\n",
    "print(f\"{'-'*80}\\nFinal result:\\n{result.result}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
